import torch
import torch.nn as nn
import numpy as np
from PIL import Image
import os
import logging
logger = logging.getLogger(__name__)
from .utils import CIFAR100_TRAIN_TRAINSFORM, CIFAR100_EVAL_TRAINSFORM, unpickle, get_unsupervised_transform, CIFAR100_NORMALIZE, other_class, build_for_cifar100, multiclass_noisify


class CIFAR100(torch.utils.data.Dataset):
    """CIFAR-100 dataset read from the pickled ``cifar-100-python`` batches.

    Args:
        train: if truthy, read ``cifar-100-python/train`` under ``data_path``;
            otherwise read ``cifar-100-python/test``.
        data_path: directory that contains the ``cifar-100-python`` folder.
        unsupervised_transform: if truthy, use the transform built by
            ``get_unsupervised_transform`` instead of the train/eval transform.

    Each item is ``(transformed_image, label)`` where the label comes from the
    batch's ``b'fine_labels'`` entry.
    """

    def __init__(self, train, data_path, unsupervised_transform=False):
        super().__init__()
        split_file = 'cifar-100-python/train' if train else 'cifar-100-python/test'
        batch = unpickle(os.path.join(data_path, split_file))
        raw_pixels = batch[b'data']
        self.targets = batch[b'fine_labels']

        # Rows are reshaped to (N, 3, 32, 32) and then moved to channels-last
        # (N, 32, 32, 3) so a single sample can be handed to PIL.
        self.data = raw_pixels.reshape(-1, 3, 32, 32).transpose(0, 2, 3, 1)

        if unsupervised_transform:
            self.transform = get_unsupervised_transform(normalize=CIFAR100_NORMALIZE)
        elif train:
            self.transform = CIFAR100_TRAIN_TRAINSFORM
        else:
            self.transform = CIFAR100_EVAL_TRAINSFORM

    def __getitem__(self, idx):
        """Return ``(transformed_image, label)`` for the sample at ``idx``."""
        pixels, target = self.data[idx], self.targets[idx]
        return self.transform(Image.fromarray(pixels)), target

    def __len__(self):
        """Return the number of labels, i.e. the number of samples."""
        return len(self.targets)



class CIFAR100_TwoCrops(CIFAR100):
    """CIFAR-100 variant that yields two transformed views of every image.

    Both ``self.transform`` and ``self.transform_`` are built by separate
    calls to ``get_unsupervised_transform``; the transform chosen by the
    parent constructor is overwritten.

    Args:
        train: forwarded to ``CIFAR100``.
        data_path: forwarded to ``CIFAR100``.
    """

    def __init__(self, train, data_path):
        super().__init__(train, data_path)
        self.transform = get_unsupervised_transform(normalize=CIFAR100_NORMALIZE)
        self.transform_ = get_unsupervised_transform(normalize=CIFAR100_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``((view1, view2), label)`` for the sample at ``idx``."""
        pixels, target = self.data[idx], self.targets[idx]
        pil_image = Image.fromarray(pixels)
        # ``transform_`` is applied before ``transform``; the order is kept
        # fixed in case the transforms consume shared random state.
        first_view = self.transform_(pil_image)
        second_view = self.transform(pil_image)
        return (first_view, second_view), target


class NoisyCIFAR100(CIFAR100):
    """CIFAR-100 dataset whose labels are corrupted with synthetic noise.

    Args:
        train: forwarded to ``CIFAR100``.
        data_path: forwarded to ``CIFAR100``.
        unsupervised_transform: forwarded to ``CIFAR100``.
        noise_rate: amount of label noise. ``0.0`` leaves the labels unchanged.
        is_asym: if truthy, labels are flipped with ``multiclass_noisify``
            using a block-diagonal transition matrix (20 blocks of 5 classes).
            Otherwise ``int(noise_rate * n_samples) // 100`` samples per class
            are relabelled with ``other_class``.
        seed: seeds NumPy's global RNG and is passed to ``multiclass_noisify``
            as ``random_state``.

    After construction ``self.targets`` holds the (possibly) noisy labels and
    per-class sample counts are written to the module logger.
    """

    def __init__(self, train, data_path, unsupervised_transform=False, noise_rate=0.0, is_asym=False, seed=0):
        super(NoisyCIFAR100, self).__init__(train, data_path, unsupervised_transform=unsupervised_transform)

        nb_classes = 100
        # NOTE: this reseeds NumPy's *global* RNG, which the symmetric branch
        # draws from via np.random.choice.
        np.random.seed(seed)
        if is_asym:
            # Mistakes stay inside a block of 5 consecutive class indices; the
            # 100 classes are split into 20 such blocks.
            # NOTE(review): blocks are taken over consecutive fine-label
            # indices; whether these coincide with the real CIFAR-100
            # superclasses (e.g. 'fish') depends on the label ordering -
            # confirm against the dataset's coarse labels.
            nb_superclasses = 20
            nb_subclasses = 5
            P = np.eye(nb_classes)

            if noise_rate > 0.0:
                for i in range(nb_superclasses):
                    init, end = i * nb_subclasses, (i + 1) * nb_subclasses
                    P[init:end, init:end] = build_for_cifar100(nb_subclasses, noise_rate)

                # Flip labels once, after the full transition matrix is built.
                # Previously this ran on every loop iteration and only the
                # final result (full matrix) was used.
                clean_labels = np.array(self.targets)
                y_train_noisy = multiclass_noisify(clean_labels.copy(), P=P, random_state=seed)
                actual_noise = (y_train_noisy != clean_labels).mean()
                assert actual_noise > 0.0
                print('Actual noise %.2f' % actual_noise)
                self.targets = y_train_noisy.tolist()
        elif noise_rate > 0:
            n_samples = len(self.targets)
            n_noisy = int(noise_rate * n_samples)
            logger.info("%d Noisy samples", n_noisy)
            # Convert once instead of once per class.
            clean_labels = np.array(self.targets)
            class_index = [np.where(clean_labels == c)[0] for c in range(nb_classes)]
            # The same number of samples is corrupted in every class.
            class_noisy = int(n_noisy / 100)
            noisy_idx = []
            for d in range(nb_classes):
                noisy_class_index = np.random.choice(class_index[d], class_noisy, replace=False)
                noisy_idx.extend(noisy_class_index)
                logger.info("Class %d, number of noisy %d", d, len(noisy_class_index))
            for i in noisy_idx:
                self.targets[i] = other_class(n_classes=100, current_class=self.targets[i])

        logger.info("Pring noisy label generation statistics:")
        final_labels = np.array(self.targets)
        for i in range(nb_classes):
            n_noisy = np.sum(final_labels == i)
            logger.info("Noisy class %s, has %s samples.", i, n_noisy)

class NoisyCIFAR100_TwoCrops(NoisyCIFAR100):
    """Noisy-label CIFAR-100 variant that yields two transformed views per image.

    Both ``self.transform`` and ``self.transform_`` are built by separate
    calls to ``get_unsupervised_transform``; the transform chosen by the
    parent constructor is overwritten.

    Args:
        train: forwarded to ``NoisyCIFAR100``.
        data_path: forwarded to ``NoisyCIFAR100``.
        noise_rate: forwarded to ``NoisyCIFAR100``.
        is_asym: forwarded to ``NoisyCIFAR100``.
        seed: forwarded to ``NoisyCIFAR100``. Defaults to ``0``, which is the
            value the parent used when this parameter did not exist here.
    """

    def __init__(self, train, data_path, noise_rate=0.0, is_asym=False, seed=0):
        super(NoisyCIFAR100_TwoCrops, self).__init__(
            train, data_path, noise_rate=noise_rate, is_asym=is_asym, seed=seed)
        self.transform = get_unsupervised_transform(normalize=CIFAR100_NORMALIZE)
        self.transform_ = get_unsupervised_transform(normalize=CIFAR100_NORMALIZE)

    def __getitem__(self, idx):
        """Return ``((view1, view2), label)`` for the sample at ``idx``."""
        data = self.data[idx]
        label = self.targets[idx]
        image = Image.fromarray(data)
        # ``transform_`` is applied before ``transform``; keep this order.
        image1 = self.transform_(image)
        image2 = self.transform(image)
        return (image1, image2), label